// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.

using System;
using System.Buffers;
using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using ManagedBass;
using osu.Framework.Utils;
using osu.Framework.Audio.Callbacks;
using osu.Framework.Extensions;
using osu.Framework.Logging;

namespace osu.Framework.Audio.Track
{
    /// <summary>
    /// Processes audio sample data such that it can then be consumed to generate waveform plots of the audio.
    /// </summary>
    public class Waveform : IDisposable
    {
        /// <summary>
        /// <see cref="Point"/>s are initially generated to a 1ms resolution to cover most use cases.
        /// Expressed in seconds, as it is multiplied by the channel's sample frequency to obtain the number of samples per point.
        /// </summary>
        private const float resolution = 0.001f;

        /// <summary>
        /// The data stream is iteratively decoded to provide this many points per iteration so as to not exceed BASS's internal buffer size.
        /// </summary>
        private const int points_per_iteration = 1000;

        /// <summary>
        /// FFT1024 gives ~40hz accuracy.
        /// </summary>
        private const DataFlags fft_samples = DataFlags.FFT1024;

        /// <summary>
        /// Number of bins generated by the FFT. Must correspond to <see cref="fft_samples"/>.
        /// </summary>
        private const int fft_bins = 512;

        /// <summary>
        /// Minimum frequency for low-range (bass) frequencies. Based on lower range of bass drum fallout.
        /// </summary>
        private const float low_min = 20;

        /// <summary>
        /// Minimum frequency for mid-range frequencies. Based on higher range of bass drum fallout.
        /// Also acts as the (exclusive) upper bound of the low range.
        /// </summary>
        private const float mid_min = 100;

        /// <summary>
        /// Minimum frequency for high-range (treble) frequencies.
        /// Also acts as the (exclusive) upper bound of the mid range.
        /// </summary>
        private const float high_min = 2000;

        /// <summary>
        /// Maximum frequency for high-range (treble) frequencies. A sane value.
        /// </summary>
        private const float high_max = 12000;

        /// <summary>
        /// The channel count of the decoded audio.
        /// Only assigned once the read task has decoded the whole stream (or copied across when resampling), and 0 otherwise.
        /// </summary>
        private int channels;

        /// <summary>
        /// The generated points. Empty until the read task allocates and populates it (or until assigned when resampling).
        /// Consumers should await <see cref="readTask"/> before reading.
        /// </summary>
        private Point[] points = Array.Empty<Point>();

        /// <summary>
        /// Cancelled on disposal to abort the read task.
        /// </summary>
        private readonly CancellationTokenSource cancelSource = new CancellationTokenSource();

        /// <summary>
        /// The background task that decodes <see cref="data"/> into <see cref="points"/>. Started by the constructor.
        /// </summary>
        private readonly Task readTask;

        /// <summary>
        /// The source audio stream owned by this <see cref="Waveform"/>. Set to null once it has been disposed.
        /// </summary>
        private Stream? data;

        /// <summary>
        /// Constructs a new <see cref="Waveform"/> from provided audio data.
        /// </summary>
        /// <remarks>
        /// Decoding is started immediately on a background task and is performed in two passes over the stream:
        /// one to compute per-point peak amplitudes, and one to compute per-point FFT intensities.
        /// Failures are logged and leave the waveform with whatever points were produced up to that point.
        /// </remarks>
        /// <param name="data">
        /// The sample data stream.
        /// The <see cref="Waveform"/> will take ownership of this stream and dispose it when done reading track data
        /// (including when reading could not be started or failed part-way).
        /// If null, an empty waveform is constructed.
        /// </param>
        public Waveform(Stream? data)
        {
            this.data = data;

            var token = cancelSource.Token;

            readTask = Task.Run(() =>
            {
                if (data == null)
                    return;

                const int bytes_per_sample = 4;

                // Declared outside the try block so that the finally block can release whatever was acquired, regardless of where we bail out.
                FileCallbacks? fileCallbacks = null;
                int decodeStream = 0;
                float[]? sampleBuffer = null;

                try
                {
                    // for the time being, this code cannot run if there is no bass device available.
                    if (Bass.CurrentDevice < 0)
                    {
                        Logger.Log("Failed to read waveform as no bass device is available.");
                        return;
                    }

                    fileCallbacks = new FileCallbacks(new DataStreamFileProcedures(data));

                    decodeStream = Bass.CreateStream(StreamSystem.NoBuffer, BassFlags.Decode | BassFlags.Float, fileCallbacks.Callbacks, fileCallbacks.Handle);

                    if (decodeStream == 0)
                    {
                        logBassError("could not create stream");
                        return;
                    }

                    if (!Bass.ChannelGetInfo(decodeStream, out ChannelInfo info))
                    {
                        logBassError("could not retrieve channel information");
                        return;
                    }

                    long length = Bass.ChannelGetLength(decodeStream);

                    if (length < 0)
                    {
                        logBassError("could not retrieve channel length");
                        return;
                    }

                    // We assume one or more channels.
                    // For non-stereo tracks, we'll use the single track for both amplitudes.
                    // For anything above two tracks we'll use the first and second track.
                    Debug.Assert(info.Channels >= 1);
                    int secondChannelIndex = info.Channels > 1 ? 1 : 0;

                    // Each "point" is generated from a number of samples, each sample contains a number of channels
                    int samplesPerPoint = (int)(info.Frequency * resolution * info.Channels);

                    int bytesPerPoint = samplesPerPoint * bytes_per_sample;

                    // A very low frequency/channel product truncates to zero samples per point, which would otherwise result in a division by zero below.
                    if (bytesPerPoint <= 0)
                    {
                        Logger.Log($"Failed to read waveform as the audio format cannot be represented at the required resolution (frequency: {info.Frequency}, channels: {info.Channels}).");
                        return;
                    }

                    int pointCount = (int)(length / bytesPerPoint);

                    points = new Point[pointCount];

                    // Each iteration pulls in several samples
                    int bytesPerIteration = bytesPerPoint * points_per_iteration;

                    // The rented buffer may be larger than requested and is not cleared, so only the range reported by BASS may be consumed.
                    sampleBuffer = ArrayPool<float>.Shared.Rent(bytesPerIteration / bytes_per_sample);

                    int pointIndex = 0;

                    // Read sample data
                    while (length > 0)
                    {
                        length = Bass.ChannelGetData(decodeStream, sampleBuffer, bytesPerIteration);

                        if (length < 0 && Bass.LastError != Errors.Ended)
                        {
                            logBassError("could not retrieve sample data");
                            return;
                        }

                        // At the end of the stream, length is negative and this evaluates to zero, skipping the loop below.
                        int samplesRead = (int)(length / bytes_per_sample);

                        // Each point is composed of multiple samples
                        for (int i = 0; i < samplesRead && pointIndex < pointCount; i += samplesPerPoint)
                        {
                            token.ThrowIfCancellationRequested();

                            // Never look beyond the samples decoded in this iteration (relevant for a short final read).
                            int pointEnd = Math.Min(i + samplesPerPoint, samplesRead);

                            // Channels are interleaved in the sample data (data[0] -> channel0, data[1] -> channel1, data[2] -> channel0, etc)
                            // samplesPerPoint assumes this interleaving behaviour
                            var point = new Point();

                            for (int j = i; j + secondChannelIndex < pointEnd; j += info.Channels)
                            {
                                // Find the maximum amplitude for each channel in the point
                                point.AmplitudeLeft = Math.Max(point.AmplitudeLeft, Math.Abs(sampleBuffer[j]));
                                point.AmplitudeRight = Math.Max(point.AmplitudeRight, Math.Abs(sampleBuffer[j + secondChannelIndex]));
                            }

                            // BASS may provide unclipped samples, so clip them ourselves
                            point.AmplitudeLeft = Math.Min(1, point.AmplitudeLeft);
                            point.AmplitudeRight = Math.Min(1, point.AmplitudeRight);

                            points[pointIndex++] = point;
                        }
                    }

                    if (!Bass.ChannelSetPosition(decodeStream, 0))
                    {
                        logBassError("could not reset channel position");
                        return;
                    }

                    length = Bass.ChannelGetLength(decodeStream);

                    if (length < 0)
                    {
                        logBassError("could not retrieve channel length");
                        return;
                    }

                    // Read FFT data
                    float[] bins = new float[fft_bins];
                    int currentPoint = 0;
                    long currentByte = 0;

                    while (length > 0)
                    {
                        length = Bass.ChannelGetData(decodeStream, bins, (int)fft_samples);

                        if (length < 0)
                        {
                            if (Bass.LastError != Errors.Ended)
                            {
                                logBassError("could not retrieve FFT data");
                                return;
                            }

                            // End of stream: no new data was consumed and the bins are stale, so there's nothing left to assign.
                            break;
                        }

                        currentByte += length;

                        float lowIntensity = computeIntensity(info, bins, low_min, mid_min);
                        float midIntensity = computeIntensity(info, bins, mid_min, high_min);
                        float highIntensity = computeIntensity(info, bins, high_min, high_max);

                        // In general, the FFT function will read more data than the amount of data we have in one point
                        // so we'll be setting intensities for all points whose data fits into the amount read by the FFT
                        // We know that each data point required bytesPerPoint amount of data.
                        // The byte offset is computed as a long, as it exceeds the range of an int for long tracks.
                        for (; currentPoint < points.Length && (long)currentPoint * bytesPerPoint < currentByte; currentPoint++)
                        {
                            token.ThrowIfCancellationRequested();

                            var point = points[currentPoint];
                            point.LowIntensity = lowIntensity;
                            point.MidIntensity = midIntensity;
                            point.HighIntensity = highIntensity;
                            points[currentPoint] = point;
                        }
                    }

                    channels = info.Channels;
                }
                finally
                {
                    // A handle of zero means the stream was never created, so there's nothing to free.
                    if (decodeStream != 0 && !Bass.StreamFree(decodeStream))
                        logBassError("failed to free decode stream");

                    fileCallbacks?.Dispose();

                    data.Dispose();
                    this.data = data = null;

                    if (sampleBuffer != null)
                        ArrayPool<float>.Shared.Return(sampleBuffer);
                }
            }, token);

            void logBassError(string reason) => Logger.Log($"BASS failure while reading waveform: {reason} ({Bass.LastError})");
        }

        /// <summary>
        /// Sums the FFT bin values which fall within a frequency range.
        /// </summary>
        /// <param name="info">The channel information, of which the sample frequency is used to map frequencies to bins.</param>
        /// <param name="bins">The FFT output bins.</param>
        /// <param name="startFrequency">The frequency mapping to the first bin included in the sum.</param>
        /// <param name="endFrequency">The frequency mapping to the bin at which the sum stops (exclusive).</param>
        /// <returns>The unnormalised sum of the selected bins, or 0 if the range maps to no bins.</returns>
        private float computeIntensity(ChannelInfo info, float[] bins, float startFrequency, float endFrequency)
        {
            // Bin indices are restricted to the bounds of the provided array, so out-of-range frequencies contribute nothing.
            int firstBin = Math.Clamp(toBinIndex(startFrequency), 0, bins.Length);
            int lastBin = Math.Clamp(toBinIndex(endFrequency), 0, bins.Length);

            float total = 0;

            int bin = firstBin;

            while (bin < lastBin)
            {
                total += bins[bin];
                bin++;
            }

            return total;

            // Maps a frequency to its (unclamped) bin index.
            int toBinIndex(float frequency) => (int)(fft_bins * 2 * frequency / info.Frequency);
        }

        /// <summary>
        /// Creates a new <see cref="Waveform"/> containing a specific number of data points,
        /// where each point is a gaussian-weighted mean of the surrounding points of this <see cref="Waveform"/>.
        /// </summary>
        /// <remarks>
        /// This waits for the read task of this <see cref="Waveform"/> to complete before resampling.
        /// If cancellation is observed while resampling, an empty <see cref="Waveform"/> is returned instead.
        /// </remarks>
        /// <param name="pointCount">The number of points the resulting <see cref="Waveform"/> should contain.</param>
        /// <param name="cancellationToken">The token to cancel the task.</param>
        /// <returns>An async task for the generation of the <see cref="Waveform"/>.</returns>
        /// <exception cref="ArgumentOutOfRangeException">If <paramref name="pointCount"/> is negative.</exception>
        public async Task<Waveform> GenerateResampledAsync(int pointCount, CancellationToken cancellationToken = default)
        {
            ArgumentOutOfRangeException.ThrowIfNegative(pointCount);

            if (pointCount == 0)
                return new Waveform(null);

            await readTask.ConfigureAwait(false);

            return await Task.Run(() =>
            {
                var resampled = new Point[pointCount];

                Point[] source = points;

                // How many source points each resampled point spans. Also used as the standard deviation of the smoothing filter.
                float sourcePerResampled = (float)source.Length / pointCount;

                // The smoothing filter is truncated at this many standard deviations.
                // A gaussian holds almost all of its mass within its first 3 standard deviations, so little is lost visually,
                // while a value which is too small would noticeably alter the appearance.
                const int kernel_width_factor = 3;

                int halfWidth = (int)(sourcePerResampled * kernel_width_factor) + 1;

                // The gaussian is symmetric, so weights are stored by absolute distance from the centre.
                float[] weights = new float[halfWidth + 1];

                for (int distance = 0; distance < weights.Length; distance++)
                {
                    if (cancellationToken.IsCancellationRequested)
                        return new Waveform(null);

                    weights[distance] = (float)Blur.EvalGaussian(distance, sourcePerResampled);
                }

                for (int resampledIndex = 0; resampledIndex < pointCount; resampledIndex++)
                {
                    if (cancellationToken.IsCancellationRequested)
                        return new Waveform(null);

                    // The (fractional) source position is derived by multiplication rather than repeated addition,
                    // as accumulated floating-point errors would otherwise cause it to drift.
                    float sourcePosition = resampledIndex * sourcePerResampled;
                    int centre = (int)sourcePosition;

                    // Restrict the window to source points which actually exist. The upper bound is exclusive.
                    int windowStart = Math.Max(centre - halfWidth, 0);
                    int windowEnd = Math.Min(centre + halfWidth, source.Length);

                    var accumulated = new Point();
                    float weightSum = 0;

                    for (int sourceIndex = windowStart; sourceIndex < windowEnd; sourceIndex++)
                    {
                        float weight = weights[Math.Abs(sourceIndex - centre)];
                        Point sample = source[sourceIndex];

                        weightSum += weight;

                        accumulated.AmplitudeLeft += weight * sample.AmplitudeLeft;
                        accumulated.AmplitudeRight += weight * sample.AmplitudeRight;
                        accumulated.LowIntensity += weight * sample.LowIntensity;
                        accumulated.MidIntensity += weight * sample.MidIntensity;
                        accumulated.HighIntensity += weight * sample.HighIntensity;
                    }

                    // Normalise to a weighted mean. Skipped when nothing was accumulated, leaving the point zeroed.
                    if (weightSum > 0)
                    {
                        accumulated.AmplitudeLeft /= weightSum;
                        accumulated.AmplitudeRight /= weightSum;
                        accumulated.LowIntensity /= weightSum;
                        accumulated.MidIntensity /= weightSum;
                        accumulated.HighIntensity /= weightSum;
                    }

                    resampled[resampledIndex] = accumulated;
                }

                return new Waveform(null)
                {
                    points = resampled,
                    channels = this.channels
                };
            }, cancellationToken).ConfigureAwait(false);
        }

        /// <summary>
        /// Retrieves all the points represented by this <see cref="Waveform"/>.
        /// This is the synchronous counterpart of <see cref="GetPointsAsync"/>.
        /// </summary>
        /// <returns>The points of this <see cref="Waveform"/>.</returns>
        public Point[] GetPoints()
        {
            Task<Point[]> pending = GetPointsAsync();
            return pending.GetResultSafely();
        }

        /// <summary>
        /// Retrieves all the points represented by this <see cref="Waveform"/>, once the read task has finished.
        /// </summary>
        /// <returns>An async task yielding the points of this <see cref="Waveform"/>.</returns>
        public async Task<Point[]> GetPointsAsync()
        {
            // The points are populated by the read task, so they must not be handed out before it has finished.
            await readTask.ConfigureAwait(false);

            Point[] loaded = points;
            return loaded;
        }

        /// <summary>
        /// Retrieves the number of channels represented by each <see cref="Point"/>.
        /// This is the synchronous counterpart of <see cref="GetChannelsAsync"/>.
        /// </summary>
        /// <returns>The number of channels.</returns>
        public int GetChannels()
        {
            Task<int> pending = GetChannelsAsync();
            return pending.GetResultSafely();
        }

        /// <summary>
        /// Retrieves the number of channels represented by each <see cref="Point"/>, once the read task has finished.
        /// </summary>
        /// <returns>An async task yielding the number of channels.</returns>
        public async Task<int> GetChannelsAsync()
        {
            // The channel count is only assigned by the read task, so it must not be handed out before it has finished.
            await readTask.ConfigureAwait(false);

            int loaded = channels;
            return loaded;
        }

        #region Disposal

        /// <summary>
        /// Releases the resources held by this <see cref="Waveform"/>, cancelling the read task.
        /// </summary>
        public void Dispose()
        {
            Dispose(disposing: true);

            // Cleanup has been performed explicitly, so finalization (if any derived class adds it) is unnecessary.
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Whether disposal has already been performed, used to make repeated disposal a no-op.
        /// </summary>
        private bool isDisposed;

        /// <summary>
        /// Cancels the read task and releases the source stream if it is still held.
        /// </summary>
        /// <param name="disposing">Whether this is being invoked from <see cref="Dispose()"/>. Not consulted by this implementation.</param>
        protected virtual void Dispose(bool disposing)
        {
            if (!isDisposed)
            {
                isDisposed = true;

                cancelSource.Cancel();
                cancelSource.Dispose();

                // The read task disposes the stream itself once it has run, but it may never have been started.
                // The field is read exactly once, as the read task may concurrently set it to null.
                Stream? remaining = data;
                remaining?.Dispose();
                data = null;
            }
        }

        #endregion

        /// <summary>
        /// Represents a singular point of data in a <see cref="Waveform"/>.
        /// Each point carries a peak amplitude per channel and the FFT intensity of three frequency ranges.
        /// </summary>
        public struct Point
        {
            /// <summary>
            /// The amplitude of the left channel.
            /// Points generated from audio data store the peak absolute sample value here, clipped to a maximum of 1.
            /// </summary>
            public float AmplitudeLeft;

            /// <summary>
            /// The amplitude of the right channel.
            /// Points generated from audio data store the peak absolute sample value here, clipped to a maximum of 1.
            /// For single-channel audio this is computed from the same channel as <see cref="AmplitudeLeft"/>.
            /// </summary>
            public float AmplitudeRight;

            /// <summary>
            /// Unnormalised total intensity of the low-range (bass) frequencies.
            /// </summary>
            public float LowIntensity;

            /// <summary>
            /// Unnormalised total intensity of the mid-range frequencies.
            /// </summary>
            public float MidIntensity;

            /// <summary>
            /// Unnormalised total intensity of the high-range (treble) frequencies.
            /// </summary>
            public float HighIntensity;
        }
    }
}
